{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Step up one directory so the relative paths used in later cells\n",
    "# ('./data/' and './model_MPNN_CNN_Kiba') resolve from the parent folder.\n",
    "# NOTE(review): presumably the parent is the repository root that contains the DeepPurpose package - confirm.\n",
    "# The path is relative, so re-running this cell moves up one more level each time;\n",
    "# run it only once per kernel session.\n",
    "os.chdir('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import DeepPurpose.models as models\n",
    "from DeepPurpose.utils import *\n",
    "from DeepPurpose.dataset import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Beginning Processing...\n",
      "Beginning to extract zip file...\n",
      "Done!\n",
      "in total: 118254 drug-target pairs\n",
      "encoding drug...\n",
      "unique drugs: 2068\n",
      "drug encoding finished...\n",
      "encoding protein...\n",
      "unique target sequence: 229\n",
      "protein encoding finished...\n",
      "splitting dataset...\n",
      "Done.\n",
      "cost about 195 seconds\n"
     ]
    }
   ],
   "source": [
    "from time import time\n",
    "\n",
    "# Wall-clock timer spanning data loading, encoding/splitting and model construction.\n",
    "start_time = time()\n",
    "\n",
    "# Load the KIBA drug-target pairs from ./data/ with binary=False.\n",
    "X_drug, X_target, y = load_process_KIBA('./data/', binary=False)\n",
    "\n",
    "# Encoder names for the drug side and the target side of each pair.\n",
    "drug_encoding, target_encoding = 'MPNN', 'CNN'\n",
    "\n",
    "# Random split with fractions 0.7 / 0.1 / 0.2\n",
    "# (presumably train / val / test, matching the unpacking order below).\n",
    "split_fractions = [0.7, 0.1, 0.2]\n",
    "train, val, test = data_process(\n",
    "    X_drug, X_target, y,\n",
    "    drug_encoding, target_encoding,\n",
    "    split_method='random',\n",
    "    frac=split_fractions,\n",
    ")\n",
    "\n",
    "# Hyperparameters follow the settings provided in the paper: https://arxiv.org/abs/1801.10193\n",
    "hyperparameters = dict(\n",
    "    cls_hidden_dims=[1024, 1024, 512],\n",
    "    train_epoch=100,\n",
    "    test_every_X_epoch=10,\n",
    "    LR=0.001,\n",
    "    batch_size=128,\n",
    "    hidden_dim_drug=128,\n",
    "    mpnn_hidden_size=128,\n",
    "    mpnn_depth=3,\n",
    "    cnn_target_filters=[32, 64, 96],\n",
    "    cnn_target_kernels=[4, 8, 12],\n",
    ")\n",
    "config = generate_config(\n",
    "    drug_encoding=drug_encoding,\n",
    "    target_encoding=target_encoding,\n",
    "    **hyperparameters,\n",
    ")\n",
    "model = models.model_initialize(**config)\n",
    "\n",
    "# Report the elapsed time, truncated to whole seconds.\n",
    "elapsed_seconds = int(time() - start_time)\n",
    "print(f\"cost about {elapsed_seconds} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Let's use CPU/s!\n",
      "--- Data Preparation ---\n",
      "--- Go for Training ---\n",
      "Training at Epoch 1 iteration 0 with loss 139.439. Total time 0.00305 hours\n",
      "Training at Epoch 1 iteration 100 with loss 0.85528. Total time 0.20138 hours\n",
      "Training at Epoch 1 iteration 200 with loss 0.40343. Total time 0.34138 hours\n",
      "Training at Epoch 1 iteration 300 with loss 0.84512. Total time 0.47444 hours\n",
      "Training at Epoch 1 iteration 400 with loss 0.66289. Total time 0.60805 hours\n",
      "Training at Epoch 1 iteration 500 with loss 0.85238. Total time 0.74027 hours\n",
      "Training at Epoch 1 iteration 600 with loss 0.81808. Total time 0.87611 hours\n",
      "Validation at Epoch 1 , MSE: 0.72030 , Pearson Correlation: 0.43944 with p-value: 0.0 , Concordance Index: 0.68554\n",
      "Training at Epoch 2 iteration 0 with loss 1.07556. Total time 0.99277 hours\n",
      "Training at Epoch 2 iteration 100 with loss 0.93850. Total time 1.11694 hours\n",
      "Training at Epoch 2 iteration 200 with loss 0.50950. Total time 1.24111 hours\n",
      "Training at Epoch 2 iteration 300 with loss 0.65378. Total time 1.36888 hours\n",
      "Training at Epoch 2 iteration 400 with loss 0.53988. Total time 1.50416 hours\n",
      "Training at Epoch 2 iteration 500 with loss 0.88967. Total time 1.63472 hours\n",
      "Training at Epoch 2 iteration 600 with loss 0.75896. Total time 1.76277 hours\n",
      "Validation at Epoch 2 , MSE: 0.62829 , Pearson Correlation: 0.51199 with p-value: 0.0 , Concordance Index: 0.70848\n",
      "Training at Epoch 3 iteration 0 with loss 0.91195. Total time 1.87888 hours\n",
      "Training at Epoch 3 iteration 100 with loss 0.87829. Total time 2.01361 hours\n",
      "Training at Epoch 3 iteration 200 with loss 0.66226. Total time 2.14527 hours\n",
      "Training at Epoch 3 iteration 300 with loss 0.57633. Total time 2.28111 hours\n",
      "Training at Epoch 3 iteration 400 with loss 0.97758. Total time 2.4175 hours\n",
      "Training at Epoch 3 iteration 500 with loss 0.86774. Total time 2.55361 hours\n",
      "Training at Epoch 3 iteration 600 with loss 1.21729. Total time 2.69055 hours\n",
      "Validation at Epoch 3 , MSE: 0.50566 , Pearson Correlation: 0.53635 with p-value: 0.0 , Concordance Index: 0.71654\n",
      "Training at Epoch 4 iteration 0 with loss 0.62071. Total time 2.80777 hours\n",
      "Training at Epoch 4 iteration 100 with loss 0.52476. Total time 2.93527 hours\n",
      "Training at Epoch 4 iteration 200 with loss 0.67531. Total time 3.07222 hours\n",
      "Training at Epoch 4 iteration 300 with loss 0.55189. Total time 3.20916 hours\n",
      "Training at Epoch 4 iteration 400 with loss 0.68485. Total time 3.34583 hours\n",
      "Training at Epoch 4 iteration 500 with loss 0.44150. Total time 3.48277 hours\n",
      "Training at Epoch 4 iteration 600 with loss 0.47156. Total time 3.61916 hours\n",
      "Validation at Epoch 4 , MSE: 0.56851 , Pearson Correlation: 0.56797 with p-value: 0.0 , Concordance Index: 0.72315\n",
      "Training at Epoch 5 iteration 0 with loss 0.88668. Total time 3.73333 hours\n",
      "Training at Epoch 5 iteration 100 with loss 0.62687. Total time 3.865 hours\n",
      "Training at Epoch 5 iteration 200 with loss 0.50301. Total time 4.00138 hours\n",
      "Training at Epoch 5 iteration 300 with loss 0.81095. Total time 4.12805 hours\n",
      "Training at Epoch 5 iteration 400 with loss 0.67774. Total time 4.26472 hours\n",
      "Training at Epoch 5 iteration 500 with loss 0.86882. Total time 4.40138 hours\n",
      "Training at Epoch 5 iteration 600 with loss 0.49849. Total time 4.53777 hours\n",
      "Validation at Epoch 5 , MSE: 0.48221 , Pearson Correlation: 0.57899 with p-value: 0.0 , Concordance Index: 0.72680\n",
      "Training at Epoch 6 iteration 0 with loss 0.61511. Total time 4.65444 hours\n",
      "Training at Epoch 6 iteration 100 with loss 0.81947. Total time 4.79166 hours\n",
      "Training at Epoch 6 iteration 200 with loss 0.66717. Total time 4.92805 hours\n",
      "Training at Epoch 6 iteration 300 with loss 0.63587. Total time 5.06305 hours\n",
      "Training at Epoch 6 iteration 400 with loss 0.59426. Total time 5.19916 hours\n",
      "Training at Epoch 6 iteration 500 with loss 0.65938. Total time 5.33472 hours\n",
      "Training at Epoch 6 iteration 600 with loss 0.55365. Total time 5.47083 hours\n",
      "Validation at Epoch 6 , MSE: 0.45493 , Pearson Correlation: 0.59954 with p-value: 0.0 , Concordance Index: 0.73420\n",
      "Training at Epoch 7 iteration 0 with loss 0.45920. Total time 5.58888 hours\n",
      "Training at Epoch 7 iteration 100 with loss 0.29518. Total time 5.71361 hours\n",
      "Training at Epoch 7 iteration 200 with loss 0.65242. Total time 5.84916 hours\n",
      "Training at Epoch 7 iteration 300 with loss 0.70349. Total time 5.98722 hours\n",
      "Training at Epoch 7 iteration 400 with loss 0.62729. Total time 6.11111 hours\n",
      "Training at Epoch 7 iteration 500 with loss 0.50375. Total time 6.23583 hours\n",
      "Training at Epoch 7 iteration 600 with loss 0.36792. Total time 6.37027 hours\n",
      "Validation at Epoch 7 , MSE: 0.46791 , Pearson Correlation: 0.61566 with p-value: 0.0 , Concordance Index: 0.74448\n",
      "Training at Epoch 8 iteration 0 with loss 0.36232. Total time 6.4875 hours\n",
      "Training at Epoch 8 iteration 100 with loss 0.51948. Total time 6.62333 hours\n",
      "Training at Epoch 8 iteration 200 with loss 0.75429. Total time 6.7575 hours\n",
      "Training at Epoch 8 iteration 300 with loss 0.44889. Total time 6.89333 hours\n",
      "Training at Epoch 8 iteration 400 with loss 0.49060. Total time 7.02972 hours\n",
      "Training at Epoch 8 iteration 500 with loss 0.59408. Total time 7.15944 hours\n",
      "Training at Epoch 8 iteration 600 with loss 0.49257. Total time 7.29611 hours\n",
      "Validation at Epoch 8 , MSE: 0.42919 , Pearson Correlation: 0.63366 with p-value: 0.0 , Concordance Index: 0.74538\n",
      "Training at Epoch 9 iteration 0 with loss 0.49596. Total time 7.41361 hours\n",
      "Training at Epoch 9 iteration 100 with loss 0.80841. Total time 7.54583 hours\n",
      "Training at Epoch 9 iteration 200 with loss 0.50147. Total time 7.68305 hours\n",
      "Training at Epoch 9 iteration 300 with loss 0.61141. Total time 7.81694 hours\n",
      "Training at Epoch 9 iteration 400 with loss 0.58063. Total time 7.95277 hours\n",
      "Training at Epoch 9 iteration 500 with loss 0.48687. Total time 8.08083 hours\n",
      "Training at Epoch 9 iteration 600 with loss 0.64048. Total time 8.21472 hours\n",
      "Validation at Epoch 9 , MSE: 0.41064 , Pearson Correlation: 0.64931 with p-value: 0.0 , Concordance Index: 0.75442\n",
      "Training at Epoch 10 iteration 0 with loss 0.57446. Total time 8.33111 hours\n",
      "Training at Epoch 10 iteration 100 with loss 0.81588. Total time 8.46805 hours\n",
      "Training at Epoch 10 iteration 200 with loss 0.59722. Total time 8.605 hours\n",
      "Training at Epoch 10 iteration 300 with loss 0.52301. Total time 8.73444 hours\n",
      "Training at Epoch 10 iteration 400 with loss 0.58802. Total time 8.86055 hours\n",
      "Training at Epoch 10 iteration 500 with loss 0.38357. Total time 8.99666 hours\n",
      "Training at Epoch 10 iteration 600 with loss 0.35195. Total time 9.12972 hours\n",
      "Validation at Epoch 10 , MSE: 0.39347 , Pearson Correlation: 0.67301 with p-value: 0.0 , Concordance Index: 0.75873\n",
      "Training at Epoch 11 iteration 0 with loss 0.37873. Total time 9.24638 hours\n",
      "Training at Epoch 11 iteration 100 with loss 0.43624. Total time 9.37888 hours\n",
      "Training at Epoch 11 iteration 200 with loss 0.28653. Total time 9.50277 hours\n",
      "Training at Epoch 11 iteration 300 with loss 0.35587. Total time 9.6325 hours\n",
      "Training at Epoch 11 iteration 400 with loss 0.50246. Total time 9.76861 hours\n",
      "Training at Epoch 11 iteration 500 with loss 0.59975. Total time 9.90555 hours\n",
      "Training at Epoch 11 iteration 600 with loss 0.40750. Total time 10.0391 hours\n",
      "Validation at Epoch 11 , MSE: 0.38795 , Pearson Correlation: 0.69532 with p-value: 0.0 , Concordance Index: 0.76739\n",
      "--- Go for Testing ---\n",
      "Up to Epoch 10 Testing MSE: 0.48668658921466434 , Pearson Correlation: 0.6011887756161407 with p-value: 0.0 , Concordance Index: 0.7113144222432253\n",
      "Training at Epoch 12 iteration 0 with loss 0.46107. Total time 10.2927 hours\n",
      "Training at Epoch 12 iteration 100 with loss 0.40440. Total time 10.4205 hours\n",
      "Training at Epoch 12 iteration 200 with loss 0.40888. Total time 10.5561 hours\n",
      "Training at Epoch 12 iteration 300 with loss 0.45287. Total time 10.6855 hours\n"
     ]
    }
   ],
   "source": [
    "# Train the model, passing the train, validation and test splits.\n",
    "# The saved output below logs the training loss every 100 iterations, validation metrics\n",
    "# (MSE, Pearson correlation, concordance index) after each epoch, and a test evaluation.\n",
    "# NOTE(review): the saved log was produced on CPU at roughly 0.9 hours per epoch and stops\n",
    "# during epoch 12, while the config requests train_epoch = 100 - expect a very long run without a GPU.\n",
    "model.train(train, val, test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trained model to './model_MPNN_CNN_Kiba', a path relative to the current\n",
    "# working directory (which the first cell changed with os.chdir).\n",
    "model.save_model('./model_MPNN_CNN_Kiba')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
